import torch
import torch.nn as nn


class PlasticityModel(nn.Module):
    """Return mapping on the principal log (Hencky) strains of a deformation gradient.

    The deviatoric part of the log strain is radially projected back onto a sphere of
    radius ``yield_stress + hardening`` whenever it lies outside it. The volumetric
    (mean) log strain, and therefore det(F), is left unchanged.
    """

    def __init__(self, yield_stress: float = 0.1, hardening: float = 0.0):
        """
        Register the trainable scalar parameters of the plastic return mapping.

        Args:
            yield_stress (float): initial radius of the admissible deviatoric log-strain region.
            hardening (float): initial offset added to ``yield_stress`` to form the threshold.
        """
        super().__init__()
        # 0-dim trainable tensors; attribute names are part of the state_dict layout.
        self.yield_stress = nn.Parameter(torch.tensor(yield_stress))
        self.hardening = nn.Parameter(torch.tensor(hardening))

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Project a batch of deformation gradients back onto the yield surface.

        Args:
            F (torch.Tensor): deformation gradients, shape (B, 3, 3).

        Returns:
            torch.Tensor: plastically corrected deformation gradients, shape (B, 3, 3).
        """
        # NOTE(review): torch.linalg.svd documents that gradients flowing through U/Vh
        # are not finite when singular values repeat (e.g. F == I). Confirm this is
        # acceptable for the optimisation loop, or swap in an SVD with a safe backward.
        U, sigma, Vh = torch.linalg.svd(F)

        # Principal log strains; the floor keeps log() finite for degenerate F.
        log_strain = torch.log(sigma.clamp_min(1e-5))
        volumetric = log_strain.mean(dim=1, keepdim=True)
        deviatoric = log_strain - volumetric
        dev_magnitude = deviatoric.norm(dim=1, keepdim=True)

        radius = (self.yield_stress + self.hardening).clamp_min(1e-8)

        # Fraction of the deviatoric strain to keep: 1 inside the yield surface,
        # otherwise exactly enough to land on it. The epsilon avoids 0/0 when the
        # deviatoric strain vanishes (pure dilation or F == I).
        overshoot = (dev_magnitude - radius).clamp_min(0.0)
        keep = 1 - overshoot / (dev_magnitude + 1e-12)

        projected_sigma = torch.exp(deviatoric * keep + volumetric)
        return U @ (torch.diag_embed(projected_sigma) @ Vh)


class ElasticityModel(nn.Module):
    """Corotated hyperelasticity returning the Kirchhoff stress tau = P F^T."""

    def __init__(self, youngs_modulus_log: float = 11.49, poissons_ratio_sigmoid: float = 1.00):
        """
        Define trainable continuous physical parameters for differentiable optimization.
        The defaults are the values carried over from an earlier fit (per the original note).

        Args:
            youngs_modulus_log (float): log of Young's modulus.
            poissons_ratio_sigmoid (float): Poisson's ratio before sigmoid transformation.
        """
        super().__init__()
        self.youngs_modulus_log = nn.Parameter(torch.tensor(youngs_modulus_log))  # scalar
        self.poissons_ratio_sigmoid = nn.Parameter(torch.tensor(poissons_ratio_sigmoid))  # scalar

    def forward(self, F: torch.Tensor) -> torch.Tensor:
        """
        Compute the Kirchhoff stress of the corotated elasticity model.

        With R = U Vh (from the SVD of F) and J = prod(singular values):

            P   = 2 mu (F - R) + la (J - 1) J F^{-T}
            tau = P F^T = 2 mu (F - R) F^T + la J (J - 1) I

        The second form is evaluated directly, so F is never inverted.

        Args:
            F (torch.Tensor): deformation gradient tensor (B, 3, 3).

        Returns:
            kirchhoff_stress (torch.Tensor): Kirchhoff stress tensor (B, 3, 3).
        """
        # Recover physical parameters from their unconstrained encodings.
        youngs_modulus = self.youngs_modulus_log.exp()  # > 0
        poissons_ratio = self.poissons_ratio_sigmoid.sigmoid() * 0.49  # in (0, 0.49)

        # Lamé parameters (0-dim tensors, broadcast against the batch below).
        mu = youngs_modulus / (2 * (1 + poissons_ratio))
        la = youngs_modulus * poissons_ratio / ((1 + poissons_ratio) * (1 - 2 * poissons_ratio))

        U, sigma, Vh = torch.linalg.svd(F)  # (B,3,3), (B,3), (B,3,3)
        sigma = torch.clamp_min(sigma, 1e-5)  # keep J strictly positive for degenerate F

        R = torch.matmul(U, Vh)  # (B, 3, 3)
        J = torch.prod(sigma, dim=1).view(-1, 1, 1)  # (B, 1, 1)
        I = torch.eye(3, dtype=F.dtype, device=F.device).unsqueeze(0)  # (1, 3, 3), broadcasts over B

        Ft = F.transpose(1, 2)  # (B, 3, 3)

        # Deviatoric/rotational part: 2 mu (F - R) F^T.
        corotated_stress = 2 * mu * torch.matmul(F - R, Ft)

        # Volumetric part: la (J - 1) J F^{-T} pushed forward by F^T collapses to
        # la J (J - 1) I. J appears exactly once here; avoiding F^{-1} also keeps
        # singular F from raising.
        volume_stress = la * J * (J - 1) * I

        return corotated_stress + volume_stress  # (B, 3, 3)
